#include "../coroutine.h"
#include "../coro_call.h"
#include "unittest.hh"
#include <iostream>
#include <thread>

#include <memory>

FIXTURE_BEGIN(coro)

CASE(TestSpawn1) {
    // A null entry function is still accepted and produces a valid id.
    auto nullCoro = coro_spawn(nullptr);
    ASSERT_TRUE(nullCoro != coroutine::INVALID_CORO_ID);

    // Spawn with a user pointer, resume explicitly, and verify the body ran
    // and received exactly the pointer handed to coro_spawn.
    int value = 1;
    bool executed = false;
    auto worker = coro_spawn([&](void* arg) {
        ASSERT_TRUE(arg == &value);
        executed = true;
    }, &value);
    ASSERT_TRUE(worker != coroutine::INVALID_CORO_ID);
    coro_resume(worker);
    ASSERT_TRUE(executed);
    coro_close_all();
}

CASE(TestSpawn2) {
    int payload = 1;
    bool outerRan = false;
    bool innerRan = false;
    bool afterSpawn = false;
    // The outer coroutine spawns and resumes a nested one, then flags that
    // execution continued past the nested resume.
    auto outer = coro_spawn([&](void* arg) {
        ASSERT_TRUE(arg == &payload);
        outerRan = true;
        auto inner = coro_spawn([&](void*) {
            innerRan = true;
        });
        coro_resume(inner);
        afterSpawn = true;
    }, &payload);
    ASSERT_TRUE(outer != coroutine::INVALID_CORO_ID);

    coro_resume(outer);
    ASSERT_TRUE(outerRan);
    ASSERT_TRUE(innerRan);

    // NOTE(review): a second resume is needed before `afterSpawn` is checked,
    // which suggests control returns to the caller when `inner` finishes —
    // confirm against coro_resume semantics.
    coro_resume(outer);
    ASSERT_TRUE(afterSpawn);
    coro_close_all();
}

CASE(TestSpawn3) {
    int i = 1;
    bool run1 = false;
    bool run2 = false;
    bool afterSpawn = false;
    // c1 spawns c2, then blocks in coro_pop_wait until c2 pushes a message
    // back to c1's id. Everything is driven by coro_sched from the main side.
    auto c1 = coro_spawn([&](void* ptr){
        ASSERT_TRUE(ptr == &i);
        auto id = coro_id();
        run1 = true;
        // Capture `id` by value: it is a local of this coroutine body, and the
        // nested coroutine should not rely on this frame still being alive.
        auto c2 = coro_spawn([&, id](void*){
            coro_push(id, "push");
            run2 = true;
        });
        afterSpawn = true;
        // Only used to suspend until c2's message arrives; the value itself
        // is not inspected.
        [[maybe_unused]] auto pop = coro_pop_wait();
        coro_resume(c2);
    }, &i);
    ASSERT_TRUE(c1 != coroutine::INVALID_CORO_ID);
    // NOTE(review): three scheduler passes presumably cover "c1 starts and
    // waits", "c2 pushes", and "c1 wakes" — confirm against coro_sched.
    coro_sched();
    coro_sched();
    coro_sched();
    ASSERT_TRUE(run1);
    ASSERT_TRUE(run2);
    coro_sched();
    ASSERT_TRUE(afterSpawn);
    coro_close_all();
}

CASE(TestStart1) {
    // coro_start also tolerates a null entry function.
    auto nullCoro = coro_start(nullptr);
    ASSERT_TRUE(nullCoro != coroutine::INVALID_CORO_ID);

    int value = 1;
    bool executed = false;
    auto worker = coro_start([&](void* arg) {
        ASSERT_TRUE(arg == &value);
        executed = true;
    }, &value);
    ASSERT_TRUE(worker != coroutine::INVALID_CORO_ID);
    // No coro_resume is issued: the body is expected to have run by the time
    // coro_start returns.
    ASSERT_TRUE(executed);
    coro_close_all();
}

CASE(TestStart2) {
    int i = 1;
    bool run1 = false;
    bool run2 = false;
    bool afterSpawn = false;
    // A coroutine started with coro_start starts a nested coroutine the same
    // way; both bodies are expected to have run once coro_start returns.
    auto c1 = coro_start([&](void* ptr){
        ASSERT_TRUE(ptr == &i);
        run1 = true;
        auto c2 = coro_start([&](void*){
            run2 = true;
        });
        // The nested id was previously unused; check it like the outer one.
        ASSERT_TRUE(c2 != coroutine::INVALID_CORO_ID);
        afterSpawn = true;
    }, &i);
    ASSERT_TRUE(c1 != coroutine::INVALID_CORO_ID);
    ASSERT_TRUE(run1);
    ASSERT_TRUE(run2);
    coro_resume(c1);
    ASSERT_TRUE(afterSpawn);
    coro_close_all();
}

CASE(TestYield1) {
    bool resumedPastYield = false;
    // The body yields back to the caller immediately; only an explicit
    // resume lets it reach the flag assignment.
    auto coro = coro_start([&](void*) {
        coro_yield();
        resumedPastYield = true;
    });
    ASSERT_TRUE(coro != coroutine::INVALID_CORO_ID);
    coro_resume(coro);
    ASSERT_TRUE(resumedPastYield);
    coro_close_all();
}

CASE(TestYield2) {
    bool executed = false;
    auto pending = coro_spawn([&](void*) {
        executed = true;
    });
    ASSERT_TRUE(pending != coroutine::INVALID_CORO_ID);
    // Yielding from the main context is expected to give the spawned
    // coroutine a chance to run.
    coro_yield();
    ASSERT_TRUE(executed);
    coro_close_all();
}

CASE(TestYield3) {
    // Yielding from the main context with no coroutines alive must simply
    // return; nothing else is checked.
    coro_yield();
    coro_close_all();
}

CASE(TestSleep) {
    bool woke = false;
    auto sleeper = coro_start([&](void*) {
        coro_sleep(100);
        woke = true;
    });
    ASSERT_TRUE(sleeper != coroutine::INVALID_CORO_ID);
    // Wait just past the 100 ms deadline, then let the scheduler wake it.
    std::this_thread::sleep_for(std::chrono::milliseconds(101));
    coro_sched();
    ASSERT_TRUE(woke);
    coro_close_all();
}

CASE(TestClose) {
    // Close a coroutine before it ever runs, then resume the closed id. The
    // test only checks that this sequence completes.
    auto doomed = coro_spawn([&](void*) {});
    coro_close(doomed);
    coro_resume(doomed);
    coro_close_all();
}

CASE(TestPush1) {
    int received = 0;
    auto consumer = coro_spawn([&](void*) {
        received = std::any_cast<int>(coro_pop());
    });
    // Queue the message first, then resume: the non-blocking coro_pop inside
    // the coroutine must find it already waiting.
    coro_push(consumer, 1);
    coro_resume(consumer);
    ASSERT_TRUE(received == 1);
    coro_close_all();
}

CASE(TestPush2) {
    int received = 0;
    auto consumer = coro_spawn([&](void*) {
        received = std::any_cast<int>(coro_pop());
    });
    // The two-argument coro_resume delivers a message and resumes in one call.
    coro_resume(consumer, 1);
    ASSERT_TRUE(received == 1);
    coro_close_all();
}

CASE(TestPush3) {
    int received = 0;
    // The coroutine starts, yields at once, and pops only after it is resumed.
    auto consumer = coro_start([&](void*) {
        coro_yield();
        received = std::any_cast<int>(coro_pop());
    });
    coro_resume(consumer, 1);
    ASSERT_TRUE(received == 1);
    coro_close_all();
}

CASE(TestPop1) {
    bool nothing = true;
    // A non-blocking pop on an empty queue must come back empty.
    auto popper = coro_start([&](void*) {
        if (auto msg = coro_pop()) {
            nothing = false;
        }
    });
    ASSERT_TRUE(nothing);
    coro_close_all();
}

CASE(TestPopWait1) {
    bool nothing = true;
    bool finished = false;
    // The coroutine blocks in coro_pop_wait. A plain resume (no message)
    // wakes it with an empty result.
    auto waiter = coro_start([&](void*) {
        auto msg = coro_pop_wait();
        nothing = !msg ? nothing : false;
        finished = true;
    });
    coro_resume(waiter);
    ASSERT_TRUE(nothing && finished);
    coro_close_all();
}

CASE(TestPopWait2) {
    bool nothing = true;
    bool finished = false;
    // Wait with a 500 ms timeout. The main side resumes after 200 ms without
    // a message, so the wait ends early and empty.
    auto waiter = coro_start([&](void*) {
        auto msg = coro_pop_wait(500);
        nothing = !msg ? nothing : false;
        finished = true;
    });
    std::this_thread::sleep_for(std::chrono::milliseconds(200));
    coro_resume(waiter);
    ASSERT_TRUE(nothing && finished);
    coro_close_all();
}

CASE(TestPopWait3) {
    bool nothing = true;
    bool finished = false;
    // Same as TestPopWait2, except the resume carries a message, so the
    // timed wait must return a non-empty result.
    auto waiter = coro_start([&](void*) {
        auto msg = coro_pop_wait(500);
        nothing = !msg ? nothing : false;
        finished = true;
    });
    std::this_thread::sleep_for(std::chrono::milliseconds(200));
    coro_resume(waiter, 100);
    ASSERT_TRUE(!nothing && finished);
    coro_close_all();
}

CASE(TestQuit1) {
    bool reachedAfterQuit = false;
    auto quitter = coro_start([&](void*) {
        coro_quit();
        // Must not be reached: coro_quit is expected to end the coroutine.
        reachedAfterQuit = true;
    });
    ASSERT_TRUE(!reachedAfterQuit);
    coro_close_all();
}

CASE(TestCoroCall1) {
    // Protocol stub whose transport always reports a successful send.
    class SenderImpl final : public coroutine::CoroProtocol {
    private:
        bool send_(coroutine::CallID, const char*, std::size_t) override {
            return true;
        }
    };
    SenderImpl sender;
    int j = 0;
    bool afterCall = false;
    // callNoParam(j) issues a call with no arguments. The coroutine is
    // expected to stay suspended until receive() delivers the reply into j.
    auto c = coro_start([&](void*) {
        sender.callNoParam(j);
        afterCall = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    sender.receive(2);
    ASSERT_TRUE(j == 2);
    ASSERT_TRUE(afterCall);
    coro_close_all();
}

CASE(TestCoroCall2) {
    // Protocol stub whose transport always reports a successful send.
    class SenderImpl final : public coroutine::CoroProtocol {
    private:
        bool send_(coroutine::CallID, const char*, std::size_t) override {
            return true;
        }
    };
    SenderImpl sender;
    int j = 0;
    bool afterCall = false;
    coroutine::CallResult result = coroutine::CallResult::NO_RESULT;
    auto c = coro_start([&](void*) {
        result = sender.callNoParam(j);
        afterCall = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    // An explicit ERROR_FAIL completes the pending call with FAIL and leaves
    // the output untouched.
    sender.error(coroutine::ERROR_FAIL);
    ASSERT_TRUE(result == coroutine::CallResult::FAIL);
    ASSERT_TRUE(j == 0);
    ASSERT_TRUE(afterCall);
    coro_close_all();
}

CASE(TestCoroCall3) {
    // Protocol stub whose transport always reports a successful send.
    class SenderImpl final : public coroutine::CoroProtocol {
    private:
        bool send_(coroutine::CallID, const char*, std::size_t) override {
            return true;
        }
    };
    SenderImpl sender;
    bool afterCall = false;
    // A call with neither arguments nor result, completed by success().
    auto c = coro_start([&](void*) {
        sender.call();
        afterCall = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    sender.success();
    ASSERT_TRUE(afterCall);
    coro_close_all();
}

CASE(TestCoroCall4) {
    // Protocol stub whose transport always reports a successful send.
    class SenderImpl final : public coroutine::CoroProtocol {
    private:
        bool send_(coroutine::CallID, const char*, std::size_t) override {
            return true;
        }
    };
    SenderImpl sender;
    bool afterCall = false;
    // NOTE(review): this body is identical to TestCoroCall3 (call() completed
    // by success()). It was presumably meant to cover a different overload,
    // e.g. a success path for call(j) — confirm and adjust.
    auto c = coro_start([&](void*) {
        sender.call();
        afterCall = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    sender.success();
    ASSERT_TRUE(afterCall);
    coro_close_all();
}

CASE(TestCoroCall5) {
    // Protocol stub whose transport always reports a successful send.
    class SenderImpl final : public coroutine::CoroProtocol {
    private:
        bool send_(coroutine::CallID, const char*, std::size_t) override {
            return true;
        }
    };
    SenderImpl sender;
    bool afterCall = false;
    int j = 0;
    // call(1, j): argument 1, result written into j once receive() replies.
    auto c = coro_start([&](void*) {
        sender.call(1, j);
        afterCall = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    sender.receive(2);
    ASSERT_TRUE(j == 2);
    ASSERT_TRUE(afterCall);
    coro_close_all();
}

CASE(TestCoroCall6) {
    // Protocol stub whose transport always reports a successful send.
    class SenderImpl final : public coroutine::CoroProtocol {
    private:
        bool send_(coroutine::CallID, const char*, std::size_t) override {
            return true;
        }
    };
    SenderImpl sender;
    bool afterCall = false;
    int j = 0;
    // Same as TestCoroCall5 but with a 100 ms timeout; the reply arrives
    // before the timeout can be processed.
    auto c = coro_start([&](void*) {
        sender.call(1, j, 100);
        afterCall = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    sender.receive(2);
    ASSERT_TRUE(j == 2);
    ASSERT_TRUE(afterCall);
    coro_close_all();
}

CASE(TestCoroCall7) {
    // Protocol stub whose transport always reports a successful send.
    class SenderImpl final : public coroutine::CoroProtocol {
    private:
        bool send_(coroutine::CallID, const char*, std::size_t) override {
            return true;
        }
    };
    SenderImpl sender;
    bool afterCall = false;
    bool timeout = false;
    int j = 0;
    auto c = coro_start([&](void*) {
        timeout = (sender.call(1, j, 50) == coroutine::CallResult::TIMEOUT);
        afterCall = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    // No reply is ever sent. Once the 50 ms deadline passes, a scheduler pass
    // must complete the call with TIMEOUT and leave j untouched.
    std::this_thread::sleep_for(std::chrono::milliseconds(51));
    coro_sched();
    ASSERT_TRUE(j == 0);
    ASSERT_TRUE(afterCall);
    ASSERT_TRUE(timeout);
    coro_close_all();
}

CASE(TestCoroCall8) {
    // Protocol stub whose transport always reports a successful send.
    class SenderImpl final : public coroutine::CoroProtocol {
    private:
        bool send_(coroutine::CallID, const char*, std::size_t) override {
            return true;
        }
    };
    SenderImpl sender;
    bool afterCall = false;
    bool timeout = false;
    int j = 0;
    auto c = coro_start([&](void*) {
        timeout = (sender.callNoParam(j, 50) == coroutine::CallResult::TIMEOUT);
        afterCall = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    // No reply is ever sent; after the deadline a scheduler pass times it out.
    std::this_thread::sleep_for(std::chrono::milliseconds(51));
    coro_sched();
    ASSERT_TRUE(j == 0);
    ASSERT_TRUE(afterCall);
    ASSERT_TRUE(timeout);
    coro_close_all();
}

CASE(TestCoroCall9) {
    // Protocol stub whose transport always reports a successful send.
    class SenderImpl final : public coroutine::CoroProtocol {
    private:
        bool send_(coroutine::CallID, const char*, std::size_t) override {
            return true;
        }
    };
    SenderImpl sender;
    bool afterCall = false;
    bool timeout = false;
    int j = 0;
    auto c = coro_start([&](void*) {
        timeout = (sender.call(j, 50) == coroutine::CallResult::TIMEOUT);
        afterCall = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    // No reply is ever sent; after the deadline a scheduler pass times it out.
    std::this_thread::sleep_for(std::chrono::milliseconds(51));
    coro_sched();
    ASSERT_TRUE(j == 0);
    ASSERT_TRUE(afterCall);
    ASSERT_TRUE(timeout);
    coro_close_all();
}

CASE(TestCoroCall10) {
    // Protocol stub whose transport always reports a successful send.
    class SenderImpl final : public coroutine::CoroProtocol {
    private:
        bool send_(coroutine::CallID, const char*, std::size_t) override {
            return true;
        }
    };
    SenderImpl sender;
    bool afterCall = false;
    bool fail = false;
    int j = 0;
    auto c = coro_start([&](void*) {
        fail = (sender.call(j) == coroutine::CallResult::FAIL);
        afterCall = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    // error() with no argument completes the pending call with FAIL.
    sender.error();
    ASSERT_TRUE(j == 0);
    ASSERT_TRUE(afterCall);
    ASSERT_TRUE(fail);
    coro_close_all();
}

CASE(TestCoroCall11) {
    // Protocol stub whose transport always reports a successful send.
    class SenderImpl final : public coroutine::CoroProtocol {
    private:
        bool send_(coroutine::CallID, const char*, std::size_t) override {
            return true;
        }
    };
    SenderImpl sender;
    bool afterCall = false;
    bool fail = false;
    auto c = coro_start([&](void*) {
        fail = (sender.call() == coroutine::CallResult::FAIL);
        afterCall = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    // error() completes the argument-less call with FAIL.
    sender.error();
    ASSERT_TRUE(afterCall);
    ASSERT_TRUE(fail);
    coro_close_all();
}

CASE(TestCoroCall12) {
    // Protocol stub whose transport always reports a successful send.
    class SenderImpl final : public coroutine::CoroProtocol {
    private:
        bool send_(coroutine::CallID, const char*, std::size_t) override {
            return true;
        }
    };
    SenderImpl sender;
    bool afterCall = false;
    bool fail = false;
    int j = 0;
    auto c = coro_start([&](void*) {
        fail = (sender.callNoParam(j) == coroutine::CallResult::FAIL);
        afterCall = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    // error() completes the call with FAIL; the output must stay untouched.
    sender.error();
    ASSERT_TRUE(afterCall);
    ASSERT_TRUE(fail);
    ASSERT_TRUE(j == 0);
    coro_close_all();
}

CASE(TestCoroCall13) {
    // Protocol stub whose transport always reports a successful send.
    class SenderImpl final : public coroutine::CoroProtocol {
    private:
        bool send_(coroutine::CallID, const char*, std::size_t) override {
            return true;
        }
    };
    SenderImpl sender;
    bool afterCall = false;
    bool fail = false;
    int j = 0;
    auto c = coro_start([&](void*) {
        fail = (sender.call(1, j) == coroutine::CallResult::FAIL);
        afterCall = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    // error() completes the call with FAIL; the output must stay untouched.
    sender.error();
    ASSERT_TRUE(afterCall);
    ASSERT_TRUE(fail);
    ASSERT_TRUE(j == 0);
    coro_close_all();
}

CASE(TestCoroCall14) {
    // Protocol stub whose transport always reports a successful send.
    class SenderImpl final : public coroutine::CoroProtocol {
    private:
        bool send_(coroutine::CallID, const char*, std::size_t) override {
            return true;
        }
    };
    SenderImpl sender;
    bool afterCall = false;
    bool fail = false;
    int j = 0;
    auto c = coro_start([&](void*) {
        fail = (sender.call(1, j, 50) == coroutine::CallResult::FAIL);
        afterCall = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    // The 50 ms deadline has passed, but no coro_sched() has run. The error
    // is delivered first, so the call reports FAIL rather than TIMEOUT.
    std::this_thread::sleep_for(std::chrono::milliseconds(51));
    sender.error();
    ASSERT_TRUE(j == 0);
    ASSERT_TRUE(afterCall);
    ASSERT_TRUE(fail);
    coro_close_all();
}

CASE(TestCoroCall15) {
    // Protocol stub whose transport always reports a successful send.
    class SenderImpl final : public coroutine::CoroProtocol {
    private:
        bool send_(coroutine::CallID, const char*, std::size_t) override {
            return true;
        }
    };
    SenderImpl sender;
    bool afterCall = false;
    bool fail = false;
    int j = 0;
    auto c = coro_start([&](void*) {
        fail = (sender.callNoParam(j, 50) == coroutine::CallResult::FAIL);
        afterCall = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    // Deadline passed, but no coro_sched() has run: the error wins and the
    // call reports FAIL.
    std::this_thread::sleep_for(std::chrono::milliseconds(51));
    sender.error();
    ASSERT_TRUE(j == 0);
    ASSERT_TRUE(afterCall);
    ASSERT_TRUE(fail);
    coro_close_all();
}

CASE(TestCoroCall16) {
    // Protocol stub whose transport always reports a successful send.
    class SenderImpl final : public coroutine::CoroProtocol {
    private:
        bool send_(coroutine::CallID, const char*, std::size_t) override {
            return true;
        }
    };
    SenderImpl sender;
    bool afterCall = false;
    bool fail = false;
    int j = 0;
    auto c = coro_start([&](void*) {
        fail = (sender.call(j, 50) == coroutine::CallResult::FAIL);
        afterCall = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    // Deadline passed, but no coro_sched() has run: the error wins and the
    // call reports FAIL.
    std::this_thread::sleep_for(std::chrono::milliseconds(51));
    sender.error();
    ASSERT_TRUE(j == 0);
    ASSERT_TRUE(afterCall);
    ASSERT_TRUE(fail);
    coro_close_all();
}

CASE(TestSched1) {
    bool run = false;
    auto c = coro_spawn([&](void*) {
        run = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    // One scheduler pass is expected to run the spawned coroutine. The
    // original test set `run` but never checked it, so it verified nothing.
    coro_sched();
    ASSERT_TRUE(run);
    coro_close_all();
}

CASE(TestSched2) {
    bool run = false;
    auto c = coro_spawn([&](void*) {
        coro_sleep(50);
        run = true;
    });
    ASSERT_TRUE(c != coroutine::INVALID_CORO_ID);
    // First pass: the coroutine starts and suspends in coro_sleep.
    coro_sched();
    ASSERT_TRUE(!run);
    // Second pass, after the deadline: the sleeper wakes and finishes.
    std::this_thread::sleep_for(std::chrono::milliseconds(51));
    coro_sched();
    ASSERT_TRUE(run);
    coro_close_all();
}

CASE(TestResume1) {
    bool run1 = false;
    bool run2 = false;
    // c2 is captured by reference before it is assigned, so both ids start
    // from a known value instead of being left indeterminate.
    coroutine::CoroID c1 = coroutine::INVALID_CORO_ID;
    coroutine::CoroID c2 = coroutine::INVALID_CORO_ID;
    // c2 resumes c1; c1 then yields directly back to c2 so it can finish.
    c1 = coro_spawn([&](void*) {
        run1 = true;
        coro_yield(c2);
    });
    c2 = coro_spawn([&](void*) {
        coro_resume(c1);
        run2 = true;
    });
    coro_resume(c2);
    ASSERT_TRUE(run1);
    ASSERT_TRUE(run2);
    coro_close_all();
}

CASE(TestResume2) {
    bool run1 = false;
    bool run2 = false;
    // c2 is captured by reference before it is assigned, so both ids start
    // from a known value instead of being left indeterminate.
    coroutine::CoroID c1 = coroutine::INVALID_CORO_ID;
    coroutine::CoroID c2 = coroutine::INVALID_CORO_ID;
    // c2 resumes c1 with a message; c1 receives it via coro_pop_wait and then
    // resumes c2 so it can finish.
    c1 = coro_spawn([&](void*) {
        auto msg = coro_pop_wait();
        if (msg) {
            run1 = true;
        }
        coro_resume(c2);
    });
    c2 = coro_spawn([&](void*) {
        coro_resume(c1, 1);
        run2 = true;
    });
    coro_resume(c2);
    ASSERT_TRUE(run1);
    ASSERT_TRUE(run2);
    coro_close_all();
}

CASE(TestQuit2) {
    // Calling coro_quit from the main context, outside any coroutine, must be
    // harmless; the test only checks that it returns.
    coro_quit();
    coro_close_all();
}

FIXTURE_END(coro)
